fix pi0 prepare_language will raise an error if the task is a string#1625
fix pi0 prepare_language will raise an error if the task is a string#1625captainfffsama wants to merge 1 commit intohuggingface:mainfrom
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR fixes a bug in the Pi0 model where string-based task inputs during inference cause tokenization errors. The fix ensures string tasks are properly wrapped in a list and replicated to match the batch size.
- Adds type checking to handle single string task inputs
- Implements batch size replication for single task strings
- Aligns Pi0 behavior with smol-vla model handling
| if isinstance(tasks, str): | ||
| tasks = [tasks] | ||
|
|
||
| if len(tasks) == 1: |
There was a problem hiding this comment.
The condition len(tasks) == 1 will always be true after the previous block that converts a string to a single-element list. This means single tasks will always be replicated to match batch size, even when the original input was already a list with one element that shouldn't be replicated.
| if isinstance(tasks, str): | |
| tasks = [tasks] | |
| if len(tasks) == 1: | |
| was_string = isinstance(tasks, str) # Track if the input was originally a string | |
| if was_string: | |
| tasks = [tasks] | |
| if was_string and len(tasks) == 1: # Only replicate if the input was originally a string |
| if isinstance(tasks, str): | ||
| tasks = [tasks] | ||
|
|
||
| if len(tasks) == 1: |
There was a problem hiding this comment.
The logic assumes that when len(tasks) == 1, it should replicate the task across the batch dimension. However, this doesn't distinguish between a single string input (which should be replicated) and a legitimate single-element list (which may not need replication). Consider checking the original input type or batch size mismatch instead.
| if isinstance(tasks, str): | |
| tasks = [tasks] | |
| if len(tasks) == 1: | |
| was_string = isinstance(tasks, str) # Track if the input was originally a string | |
| if was_string: | |
| tasks = [tasks] | |
| if was_string or len(tasks) == 1 and len(tasks) != batch[OBS_STATE].shape[0]: |
|
hey @captainfffsama, this issue will be fixed when #1431 #1452 will be merged 😄 |
Fix: Handle string-based task inputs during Pi0 inference
Description:
When the task input is a string during inference with record.py, it was incorrectly split, leading to tokenization errors. This PR wraps the string in a list to ensure correct processing, aligning the behavior with smol-vla.